# Edge Impulse - OpenMV Object Detection Example
import sensor, time,omv,math,lcd,tf,gc
from button import BUTTON

button=BUTTON()  #声明按键，梦飞openmv只有一个按键D8，因此直接内部指定了按键
#lcd.init(type=2,width=240,height=320)
lcd.init()
global THRESHOLD
#THRESHOLD  = [(-3, 17, -7, 13, -12, 8)]
THRESHOLD  =[(20, 77, -15, 15, -15, 15)]
nn_input_sz =64
count=0
#声明小尺寸的画布
img_to_matching=sensor.alloc_extra_fb(nn_input_sz,nn_input_sz,sensor.GRAYSCALE)
flag_lost =0
def find_max_object(objects):
    max_size=0
    max_object=(0,0,0,0,0.0)
    for object in objects:
        if object[2]*object[3]> max_size:
            max_object=object
            max_size = object[2]*object[3]
    return max_object

def auto_color_mask(img):
    global THRESHOLD
    LAB_ERROR=15
    ROI=(int(img.width()/2-5),int(img.height()/2-5),10,10)
    if button.event():
        button.key_event=0
        while not button.event():
            img = sensor.snapshot()
            img.draw_rectangle(ROI)
            lcd.display(img)
        if button.event():
            button.key_event=0
            img = sensor.snapshot()
            statistics=img.get_statistics(roi=ROI)
            img.draw_rectangle(ROI)
            color_l=statistics.l_mode()
            color_a=statistics.a_mode()
            color_b=statistics.b_mode()
            THRESHOLD=[(color_l - LAB_ERROR,color_l + LAB_ERROR,color_a - LAB_ERROR,color_a + LAB_ERROR,color_b - LAB_ERROR,color_b + LAB_ERROR)]
        print(THRESHOLD)

def number_recongnize(img,imege,blob):
    global count
    score=0
    label='\0'
    confidence=75
    score_str=""
    for obj in tf.classify(net,imege, min_scale=1.0, scale_mul=0.5, x_overlap=0.0, y_overlap=0.0):
        out = obj.output() #数字识别结果
        max_idx = out.index(max(out))#取概率最大的一个
        score = int(out[max_idx]*100) #概率计算成%比
        if score > confidence : #判断百分比是否达到阈值 #且数字结果是0-5才认为可取
            label=labels[max_idx]   #对数字进行赋值，0-9，可以用来显示和输出
            score_str = "%s:%d%% "%(label, score)
            img.draw_string(blob.x(), blob.y()-20, score_str,scale=1.5, color=(255,255, 0))
            count=count+1
        else:
            score_str = "??:??%"
            count=0
    if count>=3 : #多次识别都是同一个数字
        count=0 #正确次数清空，方便下一次再重新统计
        if max_idx>0 and max_idx<=9 :
            #screen.display(imege)
            #img.draw_image(imege,0, 0)
            print("number:",label)#打印数字
        return label
    else:
        return None

#单个数字识别
def Mnist_number(img,blob):
    scale=1.0
    if (blob) and (blob[3]<120) and (blob[2]<120) :
        #按坐标和比例提取出色块，注意坐标长宽都向外扩大4个像素，避免图像不全
        error_s=int((blob.h()-blob.w())/2)  #目标位置，由于模板是正方形，这里做了目标扩展
        roi1=(blob.x()-error_s-1,blob.y()-1,blob.h()+5,blob.h()+5)
        scale=nn_input_sz/(blob.h())  #缩放比例系数
        img_to_matching.clear()
        img_to_matching.draw_image(img,0,0,x_scale=scale,y_scale=scale,roi=roi1)#将roi画到模板画布上
        #img_to_matching.laplacian(1)  #通过拉普拉斯变换，突出色彩分界线（数值越大效果越好，但越慢。所以用最小值，再提高画面亮度）
        #img_to_matching.gamma_corr(gamma=1.2,contrast=25) #提高画面伽马值、对比度、亮度
        img_to_matching.mode(1)
        hist=img_to_matching.get_histogram()
        thread=hist.get_threshold()
        img_to_matching.binary([(0,thread[0])]) #这个阈值适合白字黑底
        return number_recongnize(img,img_to_matching,blob)
    else :
        return None

def find_number_control(img,mode=1):
    global number_lost,THRESHOLD
    global pan_angle,tilt_angle
    max_object=None
    blobs=img.find_blobs(THRESHOLD,area_threshold=150)  #寻找色块
    if len(blobs):
        max_object=find_max_object(blobs)
        Mnist_number(img,blob=max_object)
        img.draw_rectangle(max_object.rect(),color=(255,0,0))


labels=['0','1','2','3','4','5','6','7','8','9']
try:
    # load the model, alloc the model file on the heap if we have at least 64K free after loading
    #net = tf.load("mnist_valid_f.tflite",load_to_fb=True)
    labels,net = tf.load_builtin_model("mnist")
except Exception as e:
    raise Exception('Failed to load "mnist", did you copy the .tflite and labels.txt file onto the mass-storage device? (' + str(e) + ')')
print(net)

##############################摄像头初始化部分#####################
sensor.reset()                         # Reset and initialize the sensor.
sensor.set_pixformat(sensor.GRAYSCALE)    # Set pixel format to RGB565 (or GRAYSCALE)
sensor.set_framesize(sensor.V240X240)      # Set frame size to QVGA (320x240)
sensor.skip_frames(time=1000) # Let new settings take affect.
#sensor.set_vflip(True)
#sensor.set_hmirror(True)
clock = time.clock() # Tracks FPS. 设置一个定时器用来计算帧率
while True:
    clock.tick()
    img = sensor.snapshot()
    find_number_control(img)
    lcd.display(img)
    auto_color_mask(img)
    gc.collect()
    print(clock.fps(), "fps", end="\n\n")
    if button.state():
        click_timer=time.ticks_ms()          #开始计时
        while button.state():  pass       #等待按键抬起
        if time.ticks_ms()-click_timer>2000: #按键时长超过2s
            sensor.dealloc_extra_fb()
            del net
            break                         #循环退出，回到主界面
    else :
        click_timer=time.ticks_ms()#计时更新

